{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import IPython\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import time\n",
    "\n",
    "from nara_wpe.ntt_wpe import ntt_wrapper as wpe\n",
    "from nara_wpe import project_root\n",
    "from nara_wpe.utils import stft, istft"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Minimal example with random data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def aquire_audio_data(num_channels=4, num_samples=10000):\n",
    "    \"\"\"Return Gaussian noise as a stand-in for a multichannel recording.\n",
    "\n",
    "    Args:\n",
    "        num_channels: Number of channels (rows) to generate. Defaults to 4.\n",
    "        num_samples: Number of time samples per channel. Defaults to 10000.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (num_channels, num_samples), drawn from a\n",
    "        standard normal distribution using NumPy's global random state.\n",
    "    \"\"\"\n",
    "    return np.random.normal(size=(num_channels, num_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic input for a quick smoke test of the wrapper.\n",
    "y = aquire_audio_data()\n",
    "\n",
    "# Time only the dereverberation call, not the data generation.\n",
    "start = time.perf_counter()\n",
    "x = wpe(y)\n",
    "end = time.perf_counter()\n",
    "\n",
    "print('Time: {}'.format(end - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example with real audio recordings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "WPE estimates a filter that predicts the reverberation tail of the current frame from K past time frames, the most recent of which lies 3 (delay) time frames in the past. The predicted reverberation tail is then subtracted from the observed signal.\n",
    "\n",
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of microphone channels, i.e. how many wav files get loaded.\n",
    "channels = 8\n",
    "# Playback rate in Hz handed to the audio widgets.\n",
    "sampling_rate = 16000\n",
    "\n",
    "# WPE settings. `delay` is the prediction delay in time frames; `taps` is\n",
    "# presumably the number K of past frames used by the filter - TODO confirm.\n",
    "delay, taps = 3, 10\n",
    "# Number of WPE iterations requested from wpe().\n",
    "iterations = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Audio data\n",
    "Shape: (channels, frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
    "data_dir = project_root / 'data'\n",
    "\n",
    "# Read one wav file per channel; the files are numbered starting at 1.\n",
    "signal_list = []\n",
    "for channel_number in range(1, channels + 1):\n",
    "    wav_path = data_dir / file_template.format(channel_number)\n",
    "    # sf.read returns (data, samplerate); only the samples are kept.\n",
    "    signal_list.append(sf.read(str(wav_path))[0])\n",
    "\n",
    "# Stack along a new leading axis so that y[0] is the first file's signal.\n",
    "y = np.stack(signal_list, axis=0)\n",
    "IPython.display.Audio(y[0], rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### iterative WPE\n",
    "The wpe function is fed with y. The STFT and ISTFT are included in the Matlab package of NTT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass every WPE setting from the Setup cell explicitly, so that editing\n",
    "# `taps` or `delay` there actually changes the dereverberation.\n",
    "x = wpe(y, taps=taps, delay=delay, iterations=iterations)\n",
    "# Listen to the first channel of the dereverberated signal.\n",
    "IPython.display.Audio(x[0], rate=sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Power spectrum \n",
    "Before and after applying NTT WPE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stft_options = dict(\n",
    "    size=512,\n",
    "    shift=128,\n",
    "    window_length=None,\n",
    "    fading=True,\n",
    "    pad=True,\n",
    "    symmetric_window=False\n",
    ")\n",
    "# Move the last STFT axis to the front so each array is indexed as\n",
    "# (frequency bin, channel, frame).\n",
    "Y = stft(y, **stft_options).transpose(2, 0, 1)\n",
    "X = stft(x, **stft_options).transpose(2, 0, 1)\n",
    "\n",
    "# Use one dB range for both panels: with different colour limits the same\n",
    "# colour would mean different levels before and after dereverberation.\n",
    "db_range = dict(vmin=-120, vmax=0)\n",
    "# Excerpt of frames shown in both panels.\n",
    "frames = slice(200, 400)\n",
    "\n",
    "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))\n",
    "im1 = ax1.imshow(20 * np.log10(np.abs(Y[:, 0, frames])), origin='lower', **db_range)\n",
    "ax1.set_xlabel('frames')\n",
    "ax1.set_ylabel('frequency bins')\n",
    "ax1.set_title('reverberated')\n",
    "im2 = ax2.imshow(20 * np.log10(np.abs(X[:, 0, frames])), origin='lower', **db_range)\n",
    "ax2.set_xlabel('frames')\n",
    "ax2.set_ylabel('frequency bins')\n",
    "ax2.set_title('dereverberated')\n",
    "# A single colorbar attached to both axes, valid for both images.\n",
    "cb = fig.colorbar(im2, ax=[ax1, ax2])\n",
    "cb.set_label('power [dB]')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py36",
   "language": "python",
   "name": "py36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
